package com.chancingpack.shiro;

import com.chancingpack.util.ContextUtil;
import org.apache.shiro.mgt.SecurityManager;
import org.apache.shiro.spring.web.ShiroFilterFactoryBean;
import org.apache.shiro.web.filter.mgt.FilterChainManager;
import org.apache.shiro.web.filter.mgt.FilterChainResolver;
import org.apache.shiro.web.filter.mgt.PathMatchingFilterChainResolver;
import org.apache.shiro.web.mgt.WebSecurityManager;
import org.apache.shiro.web.servlet.AbstractShiroFilter;
import org.apache.shiro.web.servlet.ShiroHttpServletRequest;
import org.springframework.beans.factory.BeanInitializationException;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;

public class MyShiroFilterFactoryBean extends ShiroFilterFactoryBean {

    @Override  
      public Class getObjectType() {  
        return MySpringShiroFilter.class;  
      } 

    @Override
    protected AbstractShiroFilter createInstance() throws Exception {

        SecurityManager securityManager = getSecurityManager();
        if (securityManager == null) {
            String msg = "SecurityManager property must be set.";
            throw new BeanInitializationException(msg);
        }

        if (!(securityManager instanceof WebSecurityManager)) {
            String msg = "The security manager does not implement the WebSecurityManager interface.";
            throw new BeanInitializationException(msg);
        }
        FilterChainManager manager = createFilterChainManager();

        PathMatchingFilterChainResolver chainResolver = new PathMatchingFilterChainResolver();
        chainResolver.setFilterChainManager(manager);

        return new MySpringShiroFilter((WebSecurityManager) securityManager, chainResolver);
    }

    private static final class MySpringShiroFilter extends AbstractShiroFilter {  

        protected MySpringShiroFilter(WebSecurityManager webSecurityManager, FilterChainResolver resolver) {
          super();  
          if (webSecurityManager == null) {  
            throw new IllegalArgumentException("WebSecurityManager property cannot be null.");  
          }  
          setSecurityManager(webSecurityManager);  
          if (resolver != null) {  
            setFilterChainResolver(resolver);  
          }

        }
        @Override
        protected void executeChain(ServletRequest request, ServletResponse response, FilterChain origChain) throws IOException, ServletException {
            FilterChain chain = this.getExecutionChain(request, response, origChain);
            if(request instanceof  HttpServletRequest){
                HttpServletRequest httpServletRequest=(HttpServletRequest)request;
                ContextUtil.setSessionId(httpServletRequest.getSession().getId());
            }
            chain.doFilter(request, response);
        }
        @Override  
        protected ServletResponse wrapServletResponse(HttpServletResponse orig, ShiroHttpServletRequest request) {
          return new MyShiroHttpServletResponse(orig, getServletContext(), request);  
        }  
    }
}